import numpy as np
from NN_BP import *

def load_mnist(file_data, file_lab, num_classes=10):
    """Load image data and integer labels from ``.npy`` files.

    The images are scaled from ``[0, 255]`` into ``[0, 1]`` as float64.
    The labels are converted into a one-hot matrix.

    Args:
        file_data: Path (or file object) of a ``.npy`` file holding a 2-D
            array with one sample per row.
        file_lab: Path (or file object) of a ``.npy`` file holding the class
            labels. The shape is either ``(N,)`` or ``(N, k)``; in the 2-D
            case the first column is used.
        num_classes: Number of one-hot columns. Defaults to 10 (digits 0-9).

    Returns:
        A tuple ``(data, lab_onehot)``:

        - ``data`` is a float64 array of shape ``(N, D)`` divided by 255.
        - ``lab_onehot`` is a float64 array of shape ``(N, num_classes)``
          with a single 1 per row.

    Raises:
        ValueError: If the data is not 2-D, if the label count differs from
            the sample count, or if a label lies outside
            ``[0, num_classes)``.
    """
    # Load the training data and labels (from their respective files).
    # np.load's default allow_pickle=False keeps object arrays from being
    # unpickled.
    data = np.load(file_data)
    lab = np.load(file_lab)

    if np.ndim(data) != 2:
        raise ValueError(
            f"expected 2-D data (samples x features), got shape {np.shape(data)}"
        )
    N = data.shape[0]

    # Accept labels stored either as (N,) or as (N, k) with the class id in
    # column 0.
    labels = np.atleast_1d(np.asarray(lab))
    if labels.ndim > 1:
        labels = labels[:, 0]
    # Truncate toward zero like the int() conversion on float-stored labels.
    labels = labels.astype(np.int64)

    if labels.shape[0] != N:
        raise ValueError(
            f"label count {labels.shape[0]} does not match sample count {N}"
        )
    # Reject out-of-range ids. Negative ids would otherwise silently wrap
    # around to the last columns through negative indexing.
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(
            f"labels must lie in [0, {num_classes}), "
            f"got range [{labels.min()}, {labels.max()}]"
        )

    # Build the one-hot label matrix in a single vectorized assignment.
    lab_onehot = np.zeros([N, num_classes])
    lab_onehot[np.arange(N), labels] = 1

    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    data = data.astype(np.float64) / 255.0
    return data, lab_onehot